---
title: "Fit model to many surveys"
output: html_notebook
---


```{r setup, include=FALSE}
# Show the source of every chunk in the rendered notebook by default
knitr::opts_chunk$set(echo = TRUE)
```

```{r setglobal, tidy=FALSE, echo=FALSE}
# knitr provides opts_chunk, used just below to set chunk defaults
library(knitr)

# Default display size for figures produced by the chunks in this notebook
opts_chunk$set(out.width=600, out.height=600)
library(tidyverse)
library(stringr)

# Set mc.cores to the number of detected cores; rstan consults this option
# when deciding how many chains to run in parallel
options(mc.cores = parallel::detectCores())
library(rstan)
# auto_write = TRUE lets rstan save compiled models to disk so that an
# unchanged Stan program does not have to be recompiled
rstan_options(auto_write = TRUE)

# NOTE(review): tidybayes is not called directly in this file; presumably it
# is needed by the sourced model utilities -- confirm
library(tidybayes)

# tic()/toc() are used to time the long model-fitting chunk
library(tictoc)
# here() is used to locate the project root
library(here)
```

```{r dirs}
# Project root, as located by here()
root.dir <- here()

# Helper scripts and Stan programs live under code/
code.dir <- file.path(root.dir, "code")

# Simulation inputs are read from data/sim; results are written under out/sim
data.out.dir <- file.path(root.dir, "data", "sim")
out.dir <- file.path(root.dir, "out", "sim")

# Everything produced by this notebook goes into out/sim/rdestprobe
model.out.dir <- file.path(out.dir, "rdestprobe")
```

```{r}
# Inputs produced elsewhere and stored as .rds files in data.out.dir:
#   nr.censuses  - census data (iterated over, one element per school, below)
#   num.cand.gps - candidate group names for each school
#   svy.sim.out  - simulated survey results
nr.censuses <- readRDS(file.path(data.out.dir, "fb100_censuses_perfect.rds"))
num.cand.gps <- readRDS(file.path(data.out.dir, "fb100_svy_num_cand_gps.rds"))
svy.sim.out <- readRDS(file.path(data.out.dir, "fb100_svy_sims.rds"))
```


Load model utilities

```{r}
# Run the model utility script from the code directory.
# NOTE(review): presumably this is where fit_stan_model_to_survey (used when
# fitting below) is defined -- it is not defined in this file
source(file.path(code.dir, '90_model_utils.R'))
```

## Now fit the model to many surveys

```{r compile-model}
# Compile the Stan program for the random-degree model that also estimates
# probe group sizes; the compiled model object is reused for every survey fit.
rdestprobehp_model <- stan_model(
  file = file.path(code.dir, "92_model_random_degree_estprobe_hidden.stan")
)
```

This next chunk takes a long time -- about 35 hours

```{r}
# Which simulated surveys to fit: every school, the first group set,
# and the first ten surveys per school.
cur.schools <- unique(svy.sim.out$school)
cur.gp.set.idxes <- c(1)
cur.svy.idxes <- 1:10

# Make sure the output directory exists *before* starting the long fit, so the
# saveRDS() call at the end cannot fail (and lose the results) just because
# the directory is missing. This is a no-op when the directory already exists.
dir.create(model.out.dir, recursive = TRUE, showWarnings = FALSE)

tic("Fitting model to surveys")
#sink(file="degree_model.log")
# Each row of the filtered survey table is handed to fit_stan_model_to_survey
# (columns become named arguments); the remaining arguments are the same for
# every fit. The result is a list with one element per survey.
svy_models <- svy.sim.out %>%
              filter(school %in% cur.schools,
                     gp.set.idx %in% cur.gp.set.idxes,
                     svy.index %in% cur.svy.idxes) %>%
              pmap(fit_stan_model_to_survey,
                   model=rdestprobehp_model,
                   model_out_dir=model.out.dir,
                   estimate_probegp_size=TRUE,
                   # don't print iteration updates
                   refresh=0)
#sink(NULL)

# Persist the fits so later chunks can be re-run without refitting
saveRDS(svy_models,
     file=file.path(model.out.dir,  "svy_models_10perschool.rds"))
toc()

# Uncomment to reload previously saved fits instead of refitting
#svy_models <- readRDS(file=file.path(model.out.dir,  "svy_models_10perschool.rds"))
```

```{r}
# Lookup table: one row per fitted model, recording its position in
# svy_models and the survey (school / group set / survey index) it was fit to.
svy_models_info <- imap_dfr(
  svy_models,
  function(svy_fit, fit_idx) {
    tibble(index = fit_idx,
           school = svy_fit$school,
           gp.set.idx = svy_fit$gp.set.idx,
           svy.index = svy_fit$svy.index)
  }
)

saveRDS(svy_models_info,
        file = file.path(model.out.dir, "svy_models_info_10perschool.rds"))
# Uncomment to reload the saved table
#svy_models_info <- readRDS(file=file.path(model.out.dir,  "svy_models_info_10perschool.rds"))
```


```{r}
# One row per fitted model: the posterior summary of mean degree (point
# estimate plus lower/upper interval bounds) stored on each fit.
svy_models_degree_post <- imap_dfr(
  svy_models,
  function(svy_fit, fit_idx) {
    degree_summ <- svy_fit$post_mean_degree_median
    tibble(index = fit_idx,
           school = svy_fit$school,
           gp.set.idx = svy_fit$gp.set.idx,
           svy.index = svy_fit$svy.index,
           post_mean_degree_median = degree_summ$mean_degree,
           post_mean_degree_lower = degree_summ$.lower,
           post_mean_degree_upper = degree_summ$.upper)
  }
)

saveRDS(svy_models_degree_post,
        file = file.path(model.out.dir, "svy_models_degree_post_10perschool.rds"))
# Uncomment to reload the saved table
#svy_models_degree_post <- readRDS(file=file.path(model.out.dir,  "svy_models_degree_post_10perschool.rds"))
```

```{r}
# For every census, sum each candidate group's column to get the group's true
# size, then express it as a share of the census rows (group_prev).
# Result: one row per (school, group).
true_probegp_sizes <- map_dfr(nr.censuses, function(census) {
  # the school label is taken from the first row of the census
  this.school <- census$school[1]
  census.rows <- nrow(census)

  # names of the candidate-group columns for this school
  gp.cols <- num.cand.gps %>%
    filter(school == this.school) %>%
    pull(candidate.names) %>%
    purrr::simplify()

  # true size of each probe group, reshaped to one row per group
  census %>%
    summarise_at(gp.cols, sum) %>%
    tidyr::pivot_longer(everything(),
                        names_to = "group",
                        values_to = "group_size") %>%
    mutate(school = this.school,
           school_N = census.rows,
           group_prev = group_size / school_N) %>%
    select(school, everything())
})
true_probegp_sizes
```

```{r}
# Stack each fit's posterior summary of group prevalences, tag every row with
# the survey it came from and the group it refers to, then attach the true
# group sizes/prevalences by school and group.
post_ps <- map_dfr(svy_models, function(svy_fit) {
  mutate(svy_fit$post_p_median,
         school = svy_fit$school,
         gp.set.idx = svy_fit$gp.set.idx,
         svy.index = svy_fit$svy.index,
         group = svy_fit$gp.set)
})
post_ps <- left_join(post_ps, true_probegp_sizes, by = c("school", "group"))

saveRDS(post_ps,
        file = file.path(model.out.dir, "svy_models_p_post_10perschool.rds"))
# Uncomment to reload the saved table
#post_ps <- readRDS(file=file.path(model.out.dir,  "svy_models_p_post_10perschool.rds"))

post_ps
```

## Unpack and compare results

R-hats
------

```{r}
# Stack the per-parameter fit summaries of all models into one table, with the
# parameter name and Rhat moved to the front and the survey identifiers added.
post_rhats <- map_dfr(svy_models, function(svy_fit) {
  summ_tbl <- as_tibble(svy_fit$fit_summ$summary, rownames = "param")
  summ_tbl <- select(summ_tbl, param, Rhat, everything())
  mutate(summ_tbl,
         school = svy_fit$school,
         gp.set.idx = svy_fit$gp.set.idx,
         svy.index = svy_fit$svy.index)
})

saveRDS(post_rhats,
        file = file.path(model.out.dir, "svy_models_rhat_10perschool.rds"))
# Uncomment to reload the saved table
#post_rhats <- readRDS(file=file.path(model.out.dir,  "svy_models_rhat_10perschool.rds"))

post_rhats
```

```{r}
# Distribution of Rhat across every parameter of every fitted model
summary(post_rhats$Rhat)
```


Average degrees
---------------

NB: the mean of a log-normal distribution with parameters $\mu$ and $\sigma$
    is $\exp(\mu + \sigma^2/2)$
    (this is calculated as part of the model)


Join true degrees onto posterior degree estimates

```{r}
# Attach the simulation's degree quantities (estimand, truth and the other
# estimates stored with each simulated survey) to the posterior summaries,
# matching on school, group set and survey index.
post_mean_degrees <- left_join(
  svy_models_degree_post,
  select(svy.sim.out,
         school, gp.set.idx, svy.index,
         estimand.dbar, true.dbar, kp.dbar.hat, dbar.hat),
  by = c("school", "gp.set.idx", "svy.index")
)
post_mean_degrees
```


```{r}
# Both axes run from -cur.lim to cur.lim
cur.lim <- 1

# Relative error of each estimator with respect to the true mean degree,
# and the difference between the two relative errors
pmd_plot_df <- mutate(
  post_mean_degrees,
  rel_err_design = (dbar.hat - true.dbar) / true.dbar,
  rel_err_model = (post_mean_degree_median - true.dbar) / true.dbar,
  model_minus_design_re = rel_err_model - rel_err_design
)

# Scatter of model-based against design-based relative error, with the
# identity line for reference
pmd_plot <- ggplot(pmd_plot_df) +
  geom_point(aes(x = rel_err_design, y = rel_err_model), alpha = .3) +
  geom_abline(intercept = 0, slope = 1, color = "blue") +
  ylim(-cur.lim, cur.lim) +
  xlim(-cur.lim, cur.lim) +
  xlab(str_wrap("Relative error in new design-based estimate", 30)) +
  ylab(str_wrap("Relative error in new model-based posterior median estimate", 30)) +
  theme_bw() +
  coord_equal() +
  NULL

# Save the figure in both formats
ggsave(filename = file.path(model.out.dir, "sim_model_design_meandegree_est.pdf"),
       plot = pmd_plot)
ggsave(filename = file.path(model.out.dir, "sim_model_design_meandegree_est.png"),
       plot = pmd_plot)

# Correlation between the two sets of relative errors
glue::glue("Correlation is {corr}",
           corr = cor(pmd_plot_df$rel_err_model,
                      pmd_plot_df$rel_err_design))

pmd_plot
```

```{r}
# Distribution of (model relative error - design relative error) across surveys
summary(pmd_plot_df$model_minus_design_re)
```

```{r}
# Paired t-test comparing the design-based and model-based relative errors;
# the two vectors are columns of the same table, so they pair up row by row
t.test(pmd_plot_df$rel_err_design,
       pmd_plot_df$rel_err_model,
       paired=TRUE)
```
## Plot estimated and true probe group sizes

```{r}
# Upper limit used for both axes
cur.lim <- 0.15

# Scatter of estimated (p_g) against true (group_prev) group prevalence, with
# the identity line for reference.
post_ps_plot <- post_ps %>%
  ggplot(.) +
  geom_point(aes(x=group_prev,
                 y=p_g),
                 alpha=.3) +
  geom_abline(intercept=0, slope=1, color='blue') +
  ylim(0, cur.lim) + xlim(0, cur.lim) +
  theme_bw() +
  coord_equal() +
  xlab("True group prevalence") +
  # the trailing `+` keeps the terminating NULL part of the plot expression;
  # without it NULL is a separate statement that prints "NULL" in the notebook
  ylab("Model-based estimate\nof group prevalence") +
  NULL

# Save the figure in both formats
ggsave(filename=file.path(model.out.dir, 'sim_model_pg_est.pdf'),
       plot=post_ps_plot)
ggsave(filename=file.path(model.out.dir, 'sim_model_pg_est.png'),
       plot=post_ps_plot)

post_ps_plot
```


```{r}
# Per-group relative error of the estimated prevalence against the truth,
# keeping only the identifying columns and the two prevalences
ps_errs <- transmute(
  post_ps,
  school, gp.set.idx, svy.index, p_g, group_prev,
  rel_err = (p_g - group_prev) / group_prev
)

summary(ps_errs)
```

Look at relative errors in estimating the size of the probe groups

```{r}
# Add up estimated and true prevalences over the groups of each survey's group
# set, then compute the absolute and relative error of the summed estimate.
post_ps_gpset <- post_ps %>%
  select(school, gp.set.idx, svy.index, p_g, group_prev) %>%
  group_by(school, gp.set.idx, svy.index) %>%
  summarize(gp_set_est = sum(p_g),
            gp_set_prev = sum(group_prev)) %>%
  # drop the grouping left over from summarize() so later steps see a plain table
  ungroup() %>%
  mutate(rel_err = (gp_set_est - gp_set_prev)/gp_set_prev,
         err = gp_set_est - gp_set_prev)

# Upper limit used for both axes
cur.lim <- 0.6

# Scatter of estimated against true group-set prevalence, with the identity
# line for reference.
post_ps_gpset_plot <- post_ps_gpset %>%
  ggplot(.) +
  geom_point(aes(x=gp_set_prev,
                 y=gp_set_est),
                 alpha=.3) +
  geom_abline(intercept=0, slope=1, color='blue') +
  ylim(0, cur.lim) + xlim(0, cur.lim) +
  theme_bw() +
  coord_equal() +
  xlab("True group set prevalence") +
  # the trailing `+` keeps the terminating NULL part of the plot expression;
  # without it NULL is a separate statement that prints "NULL" in the notebook
  ylab("Model-based estimate\nof group set prevalence\n(New model)") +
  NULL

# Save the figure in both formats
ggsave(filename=file.path(model.out.dir, 'sim_model_pgpset_est.pdf'),
       plot=post_ps_gpset_plot)
ggsave(filename=file.path(model.out.dir, 'sim_model_pgpset_est.png'),
       plot=post_ps_gpset_plot)


summary(post_ps_gpset)

post_ps_gpset_plot
```


```{r}
# Mean signed error (estimate - truth) of the group-set prevalence estimates
glue::glue("The average error in estimating the size of a group set (estimate - truth) is {mean(post_ps_gpset$err)}")
```


```{r}
# Largest and average true group-set prevalence, as whole-number percentages
glue::glue("The biggest group set is {round(max(post_ps_gpset$gp_set_prev)*100)}% of the population.")
glue::glue("The average group set is {round(mean(post_ps_gpset$gp_set_prev)*100)}% of the population.")
```

```{r}
# Report the average true mean degree across the fitted surveys, together with
# its natural log. The message now closes the parenthesis it opens.
tmp <- post_mean_degrees
glue::glue("The average true mean network size is {truedbar} (logged, this is {log(truedbar)})",
           truedbar = mean(tmp$true.dbar))
```



